import torch
# import sys
import torch.nn as nn
import matplotlib.pyplot as plt
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed both the CPU and the CUDA generators so the random draws below
# (and everything drawn later) are reproducible from run to run.
torch.manual_seed(43)
torch.cuda.manual_seed_all(43)

# Hyper-parameters shared by the classes defined further down.
n_epochs = 5000  # optimisation steps each Layer runs in train_layer
theta = 30  # scale applied to the goodness gap inside the layer loss
hidden_dim = 10  # width of the hidden layers and of the reference vectors
n_layers = 4
n_params_per_layer = 10  # learnable scalars appended to every layer's input
n_params = n_params_per_layer * n_layers
MODEL_PATH = "model_pnn.pth"


def _unit_reference_vector():
    """Draw a (1, hidden_dim) standard-normal row on `device` and scale it to ~unit L2 norm."""
    raw = torch.normal(0, 1, size=(1, hidden_dim)).to(device)
    # The small constant keeps the division safe for a (near-)zero draw.
    return raw / (raw.norm(2, 1, keepdim=True) + 1e-4)


# One fixed reference direction per layer. They are drawn strictly in layer
# order so the seeded random stream is consumed exactly once per vector.
p_vector0 = _unit_reference_vector()
p_vector1 = _unit_reference_vector()
p_vector2 = _unit_reference_vector()
p_vector3 = _unit_reference_vector()

def function(x):
    """Return `x` squared (element-wise for tensors): the target curve y = x**2."""
    return pow(x, 2)
# Shared cosine-similarity module (reduces over dim 1, i.e. one score per row).
cos = nn.CosineSimilarity(eps=1e-6, dim=1)

class NeuralNetwork(nn.Module):
    """Stack of four `Layer`s trained layer by layer against fixed reference vectors.

    Each layer's output is scored by its cosine similarity to that layer's
    module-level reference vector (`p_vector0` .. `p_vector3`); the per-layer
    scores are summed into a "goodness" value for an (x, y, label) triple.
    """

    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # One pool of scalar parameters; every layer receives its own
        # contiguous slice of `n_params_per_layer` entries.
        self.params = torch.nn.ParameterList([nn.Parameter(torch.randn(1)*0.1) for i in range(n_params)])
        n_inputs = hidden_dim
        self.fc1 = Layer(3, n_inputs, self.params[0:n_params_per_layer])
        self.fc2 = Layer(n_inputs, n_inputs, self.params[n_params_per_layer:2*n_params_per_layer])
        self.fc3 = Layer(n_inputs, n_inputs, self.params[2*n_params_per_layer:3*n_params_per_layer])
        self.fc4 = Layer(n_inputs, 1, self.params[3*n_params_per_layer:4*n_params_per_layer])
        self.softplus = nn.Softplus()
        self.layers = [self.fc1, self.fc2, self.fc3, self.fc4]

    def predict(self, x):
        """Return the candidate y values that score higher with label 0 than with label 1.

        Args:
            x: tensor holding a single x coordinate; `x.repeat(1000, 1)` must
                yield exactly one column (the script passes a shape-(1,) tensor).

        Returns:
            The rows of a (1000, 1) grid of evenly spaced y values in
            [-0.5, 1.5] whose summed goodness for label 1.0 is strictly lower
            than for label 0.0. May be empty.
        """
        n_y = 1000
        # Pure inference: no autograd graph is needed for the returned grid rows.
        with torch.no_grad():
            x = x.repeat(n_y, 1)
            # 1000 evenly spaced candidate y coordinates.
            y = torch.linspace(-0.5, 1.5, n_y).view(n_y, 1).to(device)
            xy = torch.cat((x, y), 1)
            # Reference direction of each layer, in layer order.
            targets = (p_vector0, p_vector1, p_vector2, p_vector3)

            goodness_per_label = []
            for label in (0.0, 1.0):
                # Append the hypothesised label as a third input column.
                g = torch.cat((xy, torch.full_like(xy[:, :1], label)), 1)
                goodness = []
                for layer, target in zip(self.layers, targets):
                    g = layer(g)
                    # NOTE(review): the re-normalised activation is what gets
                    # fed to the next layer here, whereas train() forwards the
                    # un-normalised Layer outputs -- confirm this is intended.
                    g = g / (g.norm(dim=1, keepdim=True) + 1e-6)
                    goodness.append(cos(g, target.repeat(g.shape[0], 1)))
                goodness_per_label.append(sum(goodness).unsqueeze(1))

            goodness_tensor = torch.cat(goodness_per_label, 1)
            mask = goodness_tensor[:, 1] < goodness_tensor[:, 0]
            return y[mask]

    def train(self, dataloader=True, n=None):
        """Train every layer on each batch, or toggle training mode.

        This method shares its name with `nn.Module.train(mode)`. To keep
        `model.eval()` / `model.train(False)` working, a boolean first argument
        is forwarded to the stock implementation (and `self` is returned);
        anything else is treated as a dataloader.

        Args:
            dataloader: iterable of `(x, y)` batches, where `x` holds the
                (x, y) coordinates and `y` the label column -- or a bool, see above.
            n: number of x positions; the first `n * 20` rows of every batch
                form one group and the remaining rows the other.

        Raises:
            TypeError: if a dataloader is given without `n`.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)
        if n is None:
            raise TypeError("train(dataloader, n) needs n when a dataloader is given")

        split = n * 20
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # The first `split` rows are used as the *negative* stream and the
            # rest as the *positive* one (the original author notes that
            # x_neg and x_pos are inverted).
            # NOTE(review): predict() keeps y where label 0 out-scores label 1,
            # which looks consistent with this swap -- confirm before changing.
            x_neg = torch.cat((x[:split], y[:split].unsqueeze(1)), 1)
            x_pos = torch.cat((x[split:], y[split:].unsqueeze(1)), 1)
            h_pos, h_neg = x_pos, x_neg
            # The layer index restarts for every batch so each layer is always
            # paired with its own reference vector.
            for k, layer in enumerate(self.layers):
                h_pos, h_neg, loss = layer.train_layer(h_pos, h_neg, k)
                print("Layer", k + 1, "Loss", loss)
class Own_Sin(nn.Module):
    """Element-wise periodic activation ``5*sin(x) + 5*cos(x)``."""

    def forward(self, x):
        sine_term = 5 * torch.sin(x)
        cosine_term = 5 * torch.cos(x)
        return sine_term + cosine_term
class Layer(nn.Module):
    """Linear layer + `Own_Sin` activation whose input is extended by learnable scalars.

    The scalars in `params` are appended as extra columns to every input row
    before the linear map. Only these scalars are handed to the layer's own
    Adam optimiser; the linear weights are not part of it.
    """

    def __init__(self, in_features, out_features, params):
        super(Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.theta = theta
        self.params = params
        self.layer = nn.Linear(in_features+len(params), out_features).to(device)
        # self.main = nn.Sequential(self.layer, nn.GELU())
        self.main = nn.Sequential(self.layer, Own_Sin())
        self.iterations = n_epochs
        self.opt = torch.optim.Adam(self.params, lr=0.001)

    def forward(self, x):
        """Append the learnable scalars to every row of `x` and apply `self.main`.

        Args:
            x: 2-D tensor of shape (batch, in_features).

        Returns:
            Output of `self.main` for the (batch, in_features + len(params)) input.
        """
        weight = self.layer.weight
        # Build the input where the linear weights live, in their dtype.
        x = x.to(device=weight.device, dtype=weight.dtype)
        if len(self.params) == 0:
            return self.main(x)
        extra = torch.cat([p.reshape(1) for p in self.params]).to(x)
        # Same scalars for every row; `expand` avoids a per-element copy loop
        # while keeping the gradient path back to the parameters.
        tiled = extra.unsqueeze(0).expand(x.shape[0], -1)
        return self.main(torch.cat((x, tiled), dim=1))

    def goodness(self, x_pos, x_neg, k):
        """Score both streams by cosine similarity to layer `k`'s reference vector.

        Args:
            x_pos: input rows of the positive stream.
            x_neg: input rows of the negative stream.
            k: layer index (0..3) selecting `p_vector0` .. `p_vector3`.

        Returns:
            `(g_pos, g_neg)`: one cosine score per row of each stream.

        Raises:
            ValueError: if `k` has no matching reference vector.
        """
        targets = (p_vector0, p_vector1, p_vector2, p_vector3)
        if not 0 <= k < len(targets):
            raise ValueError(f"no reference vector for layer index {k}")
        target = targets[k]
        h_pos = self.forward(x_pos)
        h_neg = self.forward(x_neg)
        g_pos = cos(h_pos, target.repeat(h_pos.shape[0], 1))
        g_neg = cos(h_neg, target.repeat(h_neg.shape[0], 1))
        return g_pos, g_neg

    def train_layer(self, h_pos, h_neg, k):
        """Optimise this layer's scalars for `self.iterations` steps.

        The loss pushes the positive stream's score above the negative one
        (`log(1 + exp(-theta * delta))`) plus the squared mean gap.

        Args:
            h_pos: positive-stream input to this layer.
            h_neg: negative-stream input to this layer.
            k: index of this layer, forwarded to `goodness`.

        Returns:
            `(out_pos, out_neg, mean_loss)`: the layer outputs for both
            streams after training (no gradient attached) and the loss
            averaged over all iterations.
        """
        self.running_loss = 0.0
        for _ in range(self.iterations):
            g_pos, g_neg = self.goodness(h_pos, h_neg, k)
            delta = g_pos - g_neg
            loss = (torch.log(1 + torch.exp(-self.theta * delta))).mean() + delta.mean()**2

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            # Detach before accumulating: summing the live loss tensor would
            # keep every iteration's autograd graph reachable.
            self.running_loss += loss.detach()

        # Outputs are only passed on as data for the next layer.
        with torch.no_grad():
            out_pos = self.forward(h_pos)
            out_neg = self.forward(h_neg)
        return (
            out_pos,
            out_neg,
            self.running_loss / self.iterations,
        )

def weights_init(m):
    """Re-draw weight and bias of an ``nn.Linear`` from N(0, 30**2); ignore other modules.

    Intended for ``model.apply(weights_init)``.
    """
    if not isinstance(m, nn.Linear):
        return
    # Weight first, then bias, so the random stream is consumed in a fixed order.
    for tensor in (m.weight, m.bias):
        nn.init.normal_(tensor, mean=0.0, std=30.0)

def create_dataset(x, n, tol=None, fn=None):
    """Build correctly and wrongly labelled (x, y, label) samples around a curve.

    For every x position, 10 noisy points near the curve ("in tolerance") and
    10 points beyond the band `fn(x) +/- tol` ("out of tolerance": 5 above,
    5 below) are generated. The positive set labels in-tolerance points 1 and
    out-of-tolerance points 0; the negative set holds the same coordinates
    with the labels flipped.

    Args:
        x: tensor of the `n` x positions, shaped (1, n).
        n: number of x positions.
        tol: half-width of the tolerance band and scale of the noise. When
            omitted, the module-level global `tol` is used (original behaviour).
        fn: target curve applied to `x`. Defaults to the module-level `function`.

    Returns:
        `(positive_data, negative_data)`, each of shape (3, 20 * n) with rows
        x, y, label. The first `10 * n` columns are the in-tolerance points.

    Raises:
        NameError: if `tol` is omitted and no global `tol` is defined.
    """
    if tol is None:
        # Backward compatibility: the script defines `tol` at module level.
        try:
            tol = globals()["tol"]
        except KeyError:
            raise NameError("create_dataset needs `tol` (argument or module-level global)") from None
    if fn is None:
        fn = function

    n_intol = 10  # no of points inside the tolerance region
    n_outtol = 10  # no of points outside the tolerance region
    half = n_outtol // 2
    curve = fn(x).view(1, n)

    # In-tolerance points: the curve plus Gaussian noise scaled by `tol`.
    x_intol = x.repeat(n_intol, 1)
    y_intol = curve.repeat(n_intol, 1) + torch.randn(n_intol, n) * tol
    ones_in = torch.ones_like(x_intol, dtype=torch.float)
    in_tol_pos = torch.cat((x_intol, y_intol, ones_in), 0).view(3, n_intol, n)
    in_tol_neg = torch.cat((x_intol, y_intol, torch.zeros_like(ones_in)), 0).view(3, n_intol, n)

    # Out-of-tolerance points: uniform between the band edge and a y-range
    # derived from a fresh noisy sample (min - 2 .. max + 2).
    x_outtol = x.repeat(n_outtol, 1)
    y_noised = curve.repeat(n_outtol, 1) + torch.randn(n_outtol, n) * tol
    y_high = fn(x_outtol[:half]) + tol  # upper tolerance band
    y_low = fn(x_outtol[:half]) - tol  # lower tolerance band
    y_max = y_noised.max() + 2
    y_min = y_noised.min() - 2
    y_above = torch.zeros_like(y_noised[:half])
    y_below = torch.zeros_like(y_noised[:half])
    # Kept as a loop (above/below interleaved per row) so the random stream is
    # consumed in the same order as before and seeded runs stay reproducible.
    for i in range(half):
        y_above[i] = (y_max - y_high[i]) * torch.rand(1, n) + y_high[i]
        y_below[i] = (y_low[i] - y_min) * torch.rand(1, n) + y_min
    y_out_tol = torch.cat((y_above, y_below), 0)  # rows above the band, then rows below

    ones_out = torch.ones_like(x_outtol, dtype=torch.float)
    # 0 is the correct label for out-of-tolerance data.
    out_tol_pos = torch.cat((x_outtol, y_out_tol, torch.zeros_like(ones_out)), 0).view(3, n_outtol, n)
    out_tol_neg = torch.cat((x_outtol, y_out_tol, ones_out), 0).view(3, n_outtol, n)

    positive_data = torch.cat((in_tol_pos, out_tol_pos), 1).flatten(1, 2)
    negative_data = torch.cat((in_tol_neg, out_tol_neg), 1).flatten(1, 2)
    return positive_data, negative_data

if __name__ == "__main__":
    # Script entry point: obtain a trained model (load the checkpoint when it
    # exists, otherwise train and save one), predict y on a grid of x values
    # and plot the prediction against the true curve.
    import os

    n = 20
    x_start = -1
    x_end = 1

    if os.path.exists(MODEL_PATH):
        # SECURITY: weights_only=False unpickles arbitrary Python objects --
        # only load checkpoints you created yourself.
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    else:
        x = torch.linspace(x_start, x_end, n).view(1, n)
        tol = 0.8  # stays a module-level name: create_dataset falls back to it
        pos, neg = create_dataset(x, n)
        dataset = torch.cat((pos, neg), 1)
        dataset = torch.utils.data.TensorDataset(dataset[:2, :].T, dataset[2, :])
        # A single un-shuffled batch holding the whole dataset (800 rows for
        # n = 20): NeuralNetwork.train splits each batch at row n * 20.
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        model = NeuralNetwork()
        model.apply(weights_init)
        model.train(dataloader, n)
        torch.save(model, MODEL_PATH)

    n_test = 20
    x_test = torch.linspace(x_start, x_end, n_test).view(n_test, 1).to(device)
    y_pred = torch.zeros(n_test, 1).to(device)
    # Inference only -- no autograd graph needed.
    with torch.no_grad():
        for i in range(n_test):
            # Mean of the accepted y candidates (NaN when none is accepted).
            y_pred[i] = model.predict(x_test[i]).mean()

    y_org = function(x_test).to(device)
    print("MSE: ", torch.mean((y_pred - y_org)**2))
    plt.figure(figsize=(6, 5))
    plt.plot(x_test.cpu().numpy(), y_pred.cpu().numpy(), label='Predicted Points', color='blue')

    plt.scatter(x_test.cpu().numpy(), y_org.cpu().numpy(), label='Original', linestyle='dashed', color='r')
    plt.title('Predicted Points vs Original Function')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend()
    plt.savefig('pnn_fig_1000.pdf')
    plt.show()